/**********************************************************************
  Copyright(c) 2022-2023 Arm Corporation All rights reserved.

  Redistribution and use in source and binary forms, with or without
  modification, are permitted provided that the following conditions
  are met:
    * Redistributions of source code must retain the above copyright
      notice, this list of conditions and the following disclaimer.
    * Redistributions in binary form must reproduce the above copyright
      notice, this list of conditions and the following disclaimer in
      the documentation and/or other materials provided with the
      distribution.
    * Neither the name of Arm Corporation nor the names of its
      contributors may be used to endorse or promote products derived
      from this software without specific prior written permission.

  THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
  "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
  LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
  A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
  OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
  SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
  LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
  DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
  THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
  (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
  OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
**********************************************************************/
#ifndef SNOW3G_INTERNAL_H
#define SNOW3G_INTERNAL_H

#include <arm_neon.h>
#ifdef SAFE_PARAM
#include "include/error.h"
#endif

/* Byte / bit size helpers */
#define MAX_KEY_LEN (16)
#define SNOW3G_4_BYTES (4)
#define SNOW3G_8_BYTES (8)
#define SNOW3G_8_BITS (8)
#define SNOW3G_16_BYTES (16)
#define SNOW3G_16_BITS (16)

/* NOTE(review): unit not stated here; presumably bytes — confirm at use sites */
#define SNOW3G_BLOCK_SIZE (8)

#define SNOW3G_KEY_LEN_IN_BYTES (16) /* 128b */
#define SNOW3G_IV_LEN_IN_BYTES (16)  /* 128b */

/*
 * NOTE(review): 0x1b is presumably the GF(2^8) reduction constant used by
 * the SNOW 3G MULx operation — confirm against the algorithm code.
 */
#define SNOW3GCONSTANT (0x1b)

/*
 * Valid SNOW3G input length is 1 to UINT32_MAX (2^32 - 1) bits.
 * SNOW3G_MAX_BYTELEN is the largest whole number of bytes inside that
 * limit (UINT32_MAX / 8, rounded down by integer division).
 * UINT32_MAX is presumably provided via <stdint.h> pulled in by <arm_neon.h>.
 */
#define SNOW3G_MIN_LEN 1
#define SNOW3G_MAX_BITLEN (UINT32_MAX)
#define SNOW3G_MAX_BYTELEN (UINT32_MAX / 8)

/*
 * 8-byte scratch value that can be viewed as one 64-bit word, two 32-bit
 * words, or SNOW3G_8_BYTES individual bytes through the union members.
 * Keep b64 first: a brace initializer such as {0} initializes the first member.
 */
typedef union SafeBuffer {
        uint64_t b64;
        uint32_t b32[2];
        uint8_t b8[SNOW3G_8_BYTES];
} SafeBuf;

/*
 * Single-stream SNOW3G state: 16 LFSR stages plus the three FSM registers.
 * DECLARE_ALIGNED is a project macro; it presumably attaches a 16-byte
 * alignment to the snow3gKeyState1_t typedef — confirm against its definition.
 */
typedef struct snow3gKeyState1_s {
        /* 16 LFSR stages */
        uint32_t LFSR_S[16];
        /* 3 FSM states */
        uint32_t FSM_R1;
        uint32_t FSM_R2;
        uint32_t FSM_R3;
} DECLARE_ALIGNED(snow3gKeyState1_t, 16);

/*
 * NEON state with one uint32x4_t (4 x 32-bit lanes) per LFSR stage and per
 * FSM register. NOTE(review): presumably each lane carries an independent
 * stream, so this holds 4 streams side by side — confirm at use sites.
 */
typedef struct snow3gKeyState4_s {
        /* 16 LFSR stages */
        uint32x4_t LFSR_X[16];
        /* 3 FSM states */
        uint32x4_t FSM_X[3];
        /* NOTE(review): appears to be an index into LFSR_X — confirm usage */
        uint32_t iLFSR_X;
} snow3gKeyState4_t;

/*
 * NEON state with one uint32x4x2_t (two 4-lane vectors, 8 x 32-bit lanes)
 * per LFSR stage and per FSM register. NOTE(review): presumably 8 independent
 * streams, one per lane — confirm at use sites.
 */
typedef struct snow3gKeyState8_s {
        /* 16 LFSR stages */
        uint32x4x2_t LFSR_X[16];
        /* 3 FSM states */
        uint32x4x2_t FSM_X[3];
        /* NOTE(review): appears to be an index into LFSR_X — confirm usage */
        uint32_t iLFSR_X;
} snow3gKeyState8_t;

/**
 * @brief Returns the smallest element of a 32-bit length vector
 * @param out_array values to scan (read only)
 * @param dim_array number of entries in \a out_array
 * @return minimum of out_array[0..dim_array-1], or 0 when \a dim_array is 0
 */
static inline uint32_t
length_find_min(const uint32_t *out_array, const size_t dim_array)
{
        uint32_t smallest;
        size_t idx;

        /* Empty input: report 0 without touching out_array */
        if (dim_array == 0)
                return 0;

        smallest = out_array[0];
        for (idx = 1; idx < dim_array; idx++) {
                const uint32_t candidate = out_array[idx];

                smallest = (candidate < smallest) ? candidate : smallest;
        }

        return smallest;
}

/**
 * @brief Decrements every entry of a 32-bit length vector by \a subv
 * @param out_array lengths updated in place
 * @param dim_array number of entries in \a out_array
 * @param subv amount subtracted from each entry; unsigned arithmetic, so an
 *             entry smaller than \a subv wraps modulo 2^32
 */
static inline void
length_sub(uint32_t *out_array, const size_t dim_array, const uint32_t subv)
{
        size_t remaining = dim_array;

        /* Walk from the last entry down; each entry is independent */
        while (remaining-- > 0)
                out_array[remaining] -= subv;
}

#ifdef SAFE_PARAM
/**
 * @brief Validates a vector of 32-bit byte lengths
 *
 * Each entry must be non-zero and not greater than SNOW3G_MAX_BYTELEN.
 * A NULL \a out_array is rejected even when \a dim_array is 0.
 * On any rejection imb_set_errno(NULL, IMB_ERR_CIPH_LEN) is called once.
 *
 * @param out_array lengths to validate
 * @param dim_array number of entries in \a out_array
 * @retval 0 NULL array, or a zero / too-large length was found
 * @retval 1 all lengths are in range
 */
static inline uint32_t
length_check(const uint32_t *out_array, const size_t dim_array)
{
        size_t idx = 0;

        if (out_array != NULL) {
                for (; idx < dim_array; idx++) {
                        const uint32_t len = out_array[idx];

                        if (len == 0 || len > SNOW3G_MAX_BYTELEN)
                                break;
                }
                /* Loop ran to completion: every entry passed */
                if (idx == dim_array)
                        return 1;
        }

        imb_set_errno(NULL, IMB_ERR_CIPH_LEN);
        return 0;
}

/**
 * @brief Validates a vector of 64-bit byte lengths
 *
 * Each entry must be non-zero and not greater than SNOW3G_MAX_BYTELEN.
 * A NULL \a out_array is rejected even when \a dim_array is 0.
 * On any rejection imb_set_errno(NULL, IMB_ERR_CIPH_LEN) is called once.
 *
 * @param out_array lengths to validate
 * @param dim_array number of entries in \a out_array
 * @retval 0 NULL array, or a zero / too-large length was found
 * @retval 1 all lengths are in range
 */
static inline uint32_t
length64_check(const uint64_t *out_array, const size_t dim_array)
{
        size_t n;
        int bad = (out_array == NULL);

        /* Stops at the first offending entry; never dereferences NULL */
        for (n = 0; !bad && n < dim_array; n++)
                bad = (out_array[n] == 0) ||
                      (out_array[n] > SNOW3G_MAX_BYTELEN);

        if (bad) {
                imb_set_errno(NULL, IMB_ERR_CIPH_LEN);
                return 0;
        }

        return 1;
}
#endif

/**
 * @brief Stores four 32-bit lengths into consecutive array slots
 *
 * out_array[0..3] receive length1..length4 in argument order; entries
 * beyond index 3 are left untouched.
 *
 * @param out_array destination with room for at least 4 entries
 */
static inline void
length_copy_4(uint32_t *out_array,
              const uint32_t length1, const uint32_t length2,
              const uint32_t length3, const uint32_t length4)
{
        uint32_t *dst = out_array;

        *dst++ = length1;
        *dst++ = length2;
        *dst++ = length3;
        *dst = length4;
}

/**
 * @brief Stores eight 32-bit lengths into consecutive array slots
 *
 * out_array[0..7] receive length1..length8 in argument order; entries
 * beyond index 7 are left untouched.
 *
 * @param out_array destination with room for at least 8 entries
 */
static inline void
length_copy_8(uint32_t *out_array,
              const uint32_t length1, const uint32_t length2,
              const uint32_t length3, const uint32_t length4,
              const uint32_t length5, const uint32_t length6,
              const uint32_t length7, const uint32_t length8)
{
        const uint32_t vals[8] = {
                length1, length2, length3, length4,
                length5, length6, length7, length8
        };
        size_t k;

        for (k = 0; k < 8; k++)
                out_array[k] = vals[k];
}

#ifdef SAFE_PARAM
/**
 * @brief Verifies that a pointer array and its first \a dim_array entries are non-NULL
 *
 * On failure imb_set_errno(NULL, errnum) is called once.
 *
 * @param out_array array of pointers to inspect
 * @param dim_array number of entries to inspect
 * @param errnum error code passed to imb_set_errno on failure
 * @retval 0 \a out_array itself or one of its entries is NULL
 * @retval 1 all OK
 */
static inline int
ptr_check(void *out_array[], const size_t dim_array, const int errnum)
{
        size_t j;
        int ok = (out_array != NULL);

        /* Stops at the first NULL entry; never dereferences a NULL array */
        for (j = 0; ok && j < dim_array; j++)
                ok = (out_array[j] != NULL);

        if (!ok)
                imb_set_errno(NULL, errnum);

        return ok;
}
#endif

#ifdef SAFE_PARAM
/**
 * @brief Verifies that a const-pointer array and its first \a dim_array entries are non-NULL
 *
 * On failure imb_set_errno(NULL, errnum) is called once.
 *
 * @param out_array array of const pointers to inspect
 * @param dim_array number of entries to inspect
 * @param errnum error code passed to imb_set_errno on failure
 * @retval 0 \a out_array itself or one of its entries is NULL
 * @retval 1 all OK
 */
static inline int
cptr_check(const void * const out_array[],
           const size_t dim_array,
           const int errnum)
{
        size_t pos = 0;

        if (out_array != NULL) {
                /* Advance past every non-NULL entry */
                while (pos < dim_array && out_array[pos] != NULL)
                        pos++;
                if (pos == dim_array)
                        return 1;
        }

        imb_set_errno(NULL, errnum);
        return 0;
}
#endif

/**
 * @brief Fills the first four slots of a pointer array
 *
 * out_array[0..3] receive ptr1..ptr4 in argument order. Only the pointer
 * values are stored; nothing they point to is read.
 */
static inline void
ptr_copy_4(void *out_array[],
           void *ptr1, void *ptr2, void *ptr3, void *ptr4)
{
        void **slot = out_array;

        *slot++ = ptr1;
        *slot++ = ptr2;
        *slot++ = ptr3;
        *slot = ptr4;
}

/**
 * @brief Fills the first four slots of a const-pointer array
 *
 * out_array[0..3] receive ptr1..ptr4 in argument order. Only the pointer
 * values are stored; nothing they point to is read.
 */
static inline void
cptr_copy_4(const void *out_array[],
            const void *ptr1, const void *ptr2,
            const void *ptr3, const void *ptr4)
{
        const void *const src[4] = { ptr1, ptr2, ptr3, ptr4 };
        size_t n;

        for (n = 0; n < 4; n++)
                out_array[n] = src[n];
}

/**
 * @brief Fills the first eight slots of a pointer array
 *
 * out_array[0..7] receive ptr1..ptr8 in argument order. Only the pointer
 * values are stored; nothing they point to is read.
 */
static inline void
ptr_copy_8(void *out_array[],
           void *ptr1, void *ptr2, void *ptr3, void *ptr4,
           void *ptr5, void *ptr6, void *ptr7, void *ptr8)
{
        void *const src[8] = {
                ptr1, ptr2, ptr3, ptr4, ptr5, ptr6, ptr7, ptr8
        };
        size_t n;

        for (n = 0; n < sizeof(src) / sizeof(src[0]); n++)
                out_array[n] = src[n];
}

/**
 * @brief Fills the first eight slots of a const-pointer array
 *
 * out_array[0..7] receive ptr1..ptr8 in argument order. Only the pointer
 * values are stored; nothing they point to is read.
 */
static inline void
cptr_copy_8(const void *out_array[],
            const void *ptr1, const void *ptr2,
            const void *ptr3, const void *ptr4,
            const void *ptr5, const void *ptr6,
            const void *ptr7, const void *ptr8)
{
        const void **dst = out_array;

        *dst++ = ptr1;
        *dst++ = ptr2;
        *dst++ = ptr3;
        *dst++ = ptr4;
        *dst++ = ptr5;
        *dst++ = ptr6;
        *dst++ = ptr7;
        *dst = ptr8;
}

#endif /* SNOW3G_INTERNAL_H */
